import matplotlib.pyplot as plt
import pandas as pd

# Read CSVs: one exported evaluation mean-reward log per (environment, algorithm)
# run. The smoothing and plotting code below reads the "Step" and "Value"
# columns of each frame.
def _read_eval_reward_csv(env, algo, run=1):
    """Load the exported ``eval_mean_reward`` log of one training run.

    Args:
        env: Environment id. It is used both as the log sub-directory and
            inside the file name (e.g. ``"Pong-Atari2600"``).
        algo: Algorithm prefix of the run directory, e.g. ``"PPO"`` or ``"A2C"``.
        run: Run index appended to ``algo`` (``PPO_1``, ``PPO_2``, ...).

    Returns:
        The CSV contents as a ``pandas.DataFrame``.

    The path is built from ``env``/``algo``/``run``. That way the directory and
    the file name always refer to the same run, and the data always matches the
    algorithm named in the variable it is assigned to.
    """
    run_id = f"{algo}_{run}"
    return pd.read_csv(f"logs/{env}/{run_id}/run-{env}_{run_id}-tag-eval_mean_reward.csv")


ai3_ppo = _read_eval_reward_csv("AdventureIsland3-Nes", "PPO")
ai3_a2c = _read_eval_reward_csv("AdventureIsland3-Nes", "A2C")
airstriker_ppo = _read_eval_reward_csv("Airstriker-Genesis", "PPO")
airstriker_a2c = _read_eval_reward_csv("Airstriker-Genesis", "A2C")
astro_ppo = _read_eval_reward_csv("AstroRoboSasa-Nes", "PPO")
astro_a2c = _read_eval_reward_csv("AstroRoboSasa-Nes", "A2C")
circuscharlie_ppo = _read_eval_reward_csv("CircusCharlie-Nes", "PPO")
circuscharlie_a2c = _read_eval_reward_csv("CircusCharlie-Nes", "A2C")
pong_ppo = _read_eval_reward_csv("Pong-Atari2600", "PPO")
pong_a2c = _read_eval_reward_csv("Pong-Atari2600", "A2C")
seaquest_ppo = _read_eval_reward_csv("Seaquest-Atari2600", "PPO")
seaquest_a2c = _read_eval_reward_csv("Seaquest-Atari2600", "A2C")

# Create smoothened graphs: each run gets an exponentially weighted moving
# average of its raw "Value" series, stored next to it as "Value Smooth".
# SMOOTHEN is the per-step decay applied to older points. The pandas ``alpha``
# (the weight of the newest point) is therefore ``1 - SMOOTHEN``.
SMOOTHEN = 0.95
for _run in (
    ai3_ppo, ai3_a2c,
    airstriker_ppo, airstriker_a2c,
    astro_ppo, astro_a2c,
    circuscharlie_ppo, circuscharlie_a2c,
    pong_ppo, pong_a2c,
    seaquest_ppo, seaquest_a2c,
):
    _run["Value Smooth"] = _run["Value"].ewm(alpha=(1 - SMOOTHEN)).mean()
del _run  # loop temporary; keep the module namespace unchanged

# Start Plotting: a 2 x 3 grid with one panel per environment.
# Top row:    AdventureIsland3, Airstriker, AstroRoboSasa
# Bottom row: CircusCharlie, Pong, Seaquest
fig, _axes_grid = plt.subplots(2, 3, figsize=(20, 10))
(ax1, ax2, ax3), (ax4, ax5, ax6) = _axes_grid
del _axes_grid

# ~~~~~ Adventure Island 3 Nes ~~~~~ #
# Colours use fixed prop-cycle slots (C0 PPO, C1 A2C, C2 Human, C3 SymReL).
# They replace the private ``Axes._get_lines.prop_cycler`` (removed in
# Matplotlib 3.8) and keep each series the same colour on every panel.
# PPO: faint raw curve plus the opaque, labelled smoothed curve
ax1.plot(ai3_ppo["Step"], ai3_ppo["Value"], lw=1, alpha=0.25, c="C0")
ax1.plot(ai3_ppo["Step"], ai3_ppo["Value Smooth"], lw=1, c="C0", label="PPO")
# A2C
ax1.plot(ai3_a2c["Step"], ai3_a2c["Value"], lw=1, alpha=0.25, c="C1")
ax1.plot(ai3_a2c["Step"], ai3_a2c["Value Smooth"], lw=1, c="C1", label="A2C")
# Human reference score (dashed), spanning 5%-95% of the axes width
ax1.axhline(5000, 0.05, 0.95, ls="--", lw=3, c="C2", label="Human")
# SymReL: distilled symbolic-rule reference score
ax1.axhline(3250, 0.05, 0.95, lw=2, c="C3", label="Distilled Symbolic Rule")
# Info
ax1.set_title("AdventureIsland3-Nes", fontsize=16)
ax1.set_ylabel("Reward", fontsize=14)

# ~~~~~ Airstriker Genesis ~~~~~ #
# Colours use fixed prop-cycle slots (C0 PPO, C1 A2C, C2 Human, C3 SymReL).
# They replace the private ``Axes._get_lines.prop_cycler`` (removed in
# Matplotlib 3.8) and keep each series the same colour on every panel.
# PPO: faint raw curve plus the opaque, labelled smoothed curve
ax2.plot(airstriker_ppo["Step"], airstriker_ppo["Value"], lw=1, alpha=0.25, c="C0")
ax2.plot(airstriker_ppo["Step"], airstriker_ppo["Value Smooth"], lw=1, c="C0", label="PPO")
# A2C
ax2.plot(airstriker_a2c["Step"], airstriker_a2c["Value"], lw=1, alpha=0.25, c="C1")
ax2.plot(airstriker_a2c["Step"], airstriker_a2c["Value Smooth"], lw=1, c="C1", label="A2C")
# Human reference score (dashed), spanning 5%-95% of the axes width
ax2.axhline(520, 0.05, 0.95, ls="--", lw=3, c="C2", label="Human")
# SymReL: distilled symbolic-rule reference score
ax2.axhline(260, 0.05, 0.95, lw=2, c="C3", label="Distilled Symbolic Rule")
# Info
ax2.set_title("Airstriker-Genesis", fontsize=16)

# ~~~~~ Astro Robo Sasa Nes ~~~~~ #
# Colours use fixed prop-cycle slots (C0 PPO, C1 A2C, C2 Human, C3 SymReL).
# They replace the private ``Axes._get_lines.prop_cycler`` (removed in
# Matplotlib 3.8) and keep each series the same colour on every panel.
# PPO: faint raw curve plus the opaque, labelled smoothed curve
ax3.plot(astro_ppo["Step"], astro_ppo["Value"], lw=1, alpha=0.25, c="C0")
ax3.plot(astro_ppo["Step"], astro_ppo["Value Smooth"], lw=1, c="C0", label="PPO")
# A2C
ax3.plot(astro_a2c["Step"], astro_a2c["Value"], lw=1, alpha=0.25, c="C1")
ax3.plot(astro_a2c["Step"], astro_a2c["Value Smooth"], lw=1, c="C1", label="A2C")
# NOTE: the Human and distilled-rule reference scores are both 1800 here.
# The two lines overlap, and the SymReL line (drawn second) sits on top.
# Human reference score (dashed), spanning 5%-95% of the axes width
ax3.axhline(1800, 0.05, 0.95, ls="--", lw=3, c="C2", label="Human")
# SymReL: distilled symbolic-rule reference score
ax3.axhline(1800, 0.05, 0.95, lw=2, c="C3", label="Distilled Symbolic Rule")
# Info
ax3.set_title("AstroRoboSasa-Nes", fontsize=16)

# ~~~~~ Circus Charlie Nes ~~~~~ #
# Colours use fixed prop-cycle slots (C0 PPO, C1 A2C, C2 Human, C3 SymReL).
# They replace the private ``Axes._get_lines.prop_cycler`` (removed in
# Matplotlib 3.8) and keep each series the same colour on every panel.
# PPO: faint raw curve plus the opaque, labelled smoothed curve
ax4.plot(circuscharlie_ppo["Step"], circuscharlie_ppo["Value"], lw=1, alpha=0.25, c="C0")
ax4.plot(circuscharlie_ppo["Step"], circuscharlie_ppo["Value Smooth"], lw=1, c="C0", label="PPO")
# A2C
ax4.plot(circuscharlie_a2c["Step"], circuscharlie_a2c["Value"], lw=1, alpha=0.25, c="C1")
ax4.plot(circuscharlie_a2c["Step"], circuscharlie_a2c["Value Smooth"], lw=1, c="C1", label="A2C")
# Human reference score (dashed), spanning 5%-95% of the axes width
ax4.axhline(7530, 0.05, 0.95, ls="--", lw=3, c="C2", label="Human")
# SymReL: distilled symbolic-rule reference score
ax4.axhline(7690, 0.05, 0.95, lw=2, c="C3", label="Distilled Symbolic Rule")
# Info: bottom-left panel carries both axis labels
ax4.set_title("CircusCharlie-Nes", fontsize=16)
ax4.set_xlabel("Total Timesteps of Training Episodes", fontsize=14)
ax4.set_ylabel("Reward", fontsize=14)

# ~~~~~ Pong Atari 2600 ~~~~~ #
# Colours use fixed prop-cycle slots (C0 PPO, C1 A2C, C2 Human, C3 SymReL).
# They replace the private ``Axes._get_lines.prop_cycler`` (removed in
# Matplotlib 3.8) and keep each series the same colour on every panel.
# PPO: faint raw curve plus the opaque, labelled smoothed curve
ax5.plot(pong_ppo["Step"], pong_ppo["Value"], lw=1, alpha=0.25, c="C0")
ax5.plot(pong_ppo["Step"], pong_ppo["Value Smooth"], lw=1, c="C0", label="PPO")
# A2C
ax5.plot(pong_a2c["Step"], pong_a2c["Value"], lw=1, alpha=0.25, c="C1")
ax5.plot(pong_a2c["Step"], pong_a2c["Value Smooth"], lw=1, c="C1", label="A2C")
# Human reference score (dashed), spanning 5%-95% of the axes width
ax5.axhline(6, 0.05, 0.95, ls="--", lw=3, c="C2", label="Human")
# SymReL: distilled symbolic-rule reference score
ax5.axhline(17, 0.05, 0.95, lw=2, c="C3", label="Distilled Symbolic Rule")
# Info
ax5.set_title("Pong-Atari2600", fontsize=16)
ax5.set_xlabel("Total Timesteps of Training Episodes", fontsize=14)

# ~~~~~ Seaquest Atari 2600 ~~~~~ #
# This panel uses a logarithmic y axis (set before plotting).
ax6.semilogy()
# Colours use fixed prop-cycle slots (C0 PPO, C1 A2C, C2 Human, C3 SymReL).
# They replace the private ``Axes._get_lines.prop_cycler`` (removed in
# Matplotlib 3.8) and keep each series the same colour on every panel.
# PPO: faint raw curve plus the opaque, labelled smoothed curve
ax6.plot(seaquest_ppo["Step"], seaquest_ppo["Value"], lw=1, alpha=0.25, c="C0")
ax6.plot(seaquest_ppo["Step"], seaquest_ppo["Value Smooth"], lw=1, c="C0", label="PPO")
# A2C
ax6.plot(seaquest_a2c["Step"], seaquest_a2c["Value"], lw=1, alpha=0.25, c="C1")
ax6.plot(seaquest_a2c["Step"], seaquest_a2c["Value Smooth"], lw=1, c="C1", label="A2C")
# Human reference score (dashed), spanning 5%-95% of the axes width.
# NOTE(review): 7530 is the same value used for CircusCharlie-Nes's human
# score; confirm this is the intended Seaquest reference, not a copy-paste slip.
ax6.axhline(7530, 0.05, 0.95, ls="--", lw=3, c="C2", label="Human")
# SymReL: distilled symbolic-rule reference score
ax6.axhline(940, 0.05, 0.95, lw=2, c="C3", label="Distilled Symbolic Rule")
# Info
ax6.set_title("Seaquest-Atari2600", fontsize=16)
ax6.set_xlabel("Total Timesteps of Training Episodes", fontsize=14)

# Add one legend for the whole figure and save it. The figure is only written
# to disk; it is not shown interactively.
# The legend belongs to ax6, the last-created (and previously implicit
# "current") axes. ``bbox_to_anchor`` is in ax6's axes coordinates; the
# negative y offset places the legend below the axes.
ax6.legend(bbox_to_anchor=(-1.2, -0.45, 1.2, 0.3), fontsize=14, ncol=4)
# bbox_inches="tight" grows the saved area so the out-of-axes legend is kept.
fig.savefig("performance-plot_1.pdf", bbox_inches="tight")
